# Kernel: hpool_max

# Copyright 2014 Nervana Systems Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


// Symbolic names for the kernel parameters as they sit in constant bank 0.
// The slots are consecutive 32-bit words from 0x140 through 0x1dc.
// I, O and B each occupy two words ([0] and [1]); the code combines each pair
// with LEA / LEA.HI.X to form a 64-bit address.
// The magic_* / shift_* pairs are consumed by the multiply-then-shift
// sequences that stand in for integer division (by P, S, RS and RST).
<CONSTANT_MAPPING>
    param_I[0]      : c[0x0][0x140]
    param_I[1]      : c[0x0][0x144]
    param_O[0]      : c[0x0][0x148]
    param_O[1]      : c[0x0][0x14c]
    param_B[0]      : c[0x0][0x150]
    param_B[1]      : c[0x0][0x154]
    param_mode      : c[0x0][0x158]
    param_N         : c[0x0][0x15c]
    param_W         : c[0x0][0x160]
    param_H         : c[0x0][0x164]
    param_D         : c[0x0][0x168]
    param_C         : c[0x0][0x16c]
    param_WN        : c[0x0][0x170]
    param_HWN       : c[0x0][0x174]
    param_DHWN      : c[0x0][0x178]
    param_P         : c[0x0][0x17c]
    param_magic_P   : c[0x0][0x180]
    param_shift_P   : c[0x0][0x184]
    param_QN        : c[0x0][0x188]
    param_PQN       : c[0x0][0x18c]
    param_MPQN      : c[0x0][0x190]
    param_pad_j     : c[0x0][0x194]
    param_pad_d     : c[0x0][0x198]
    param_pad_h     : c[0x0][0x19c]
    param_pad_w     : c[0x0][0x1a0]
    param_str_j     : c[0x0][0x1a4]
    param_str_d     : c[0x0][0x1a8]
    param_str_h     : c[0x0][0x1ac]
    param_str_w     : c[0x0][0x1b0]
    param_S         : c[0x0][0x1b4]
    param_RS        : c[0x0][0x1b8]
    param_RST       : c[0x0][0x1bc]
    param_JRST      : c[0x0][0x1c0]
    param_magic_S   : c[0x0][0x1c4]
    param_shift_S   : c[0x0][0x1c8]
    param_magic_RS  : c[0x0][0x1cc]
    param_shift_RS  : c[0x0][0x1d0]
    param_magic_RST : c[0x0][0x1d4]
    param_shift_RST : c[0x0][0x1d8]
    param_overlap   : c[0x0][0x1dc]
</CONSTANT_MAPPING>

// Register names used by the kernel.
// Registers 0-7 hold the values that live for the whole kernel: the Out, In
// and Back address pairs and bprop.
// The remaining numeric ranges overlap one another (8-40, 8-13, 14-40, 14-63),
// so names listed on different lines can end up in the same physical register.
// NOTE(review): this presumably relies on the setup-only names (tid, m, p, q,
// k, mp, ... and the lookup-table names on the 14-63 line) being dead before
// the main-loop names on the 8-40 line (load*, inOffset*, jrst_*, max, ...)
// are written. Confirm the resulting allocation before moving code between
// the setup, table-build and main-loop phases.
// A two-register name such as Out0/Out1 is also referenced in memory operands
// by its bare name, e.g. [Out].
<REGISTER_MAPPING>

    0-7   : Out<0-1>, In<0-1>, Back<0-1>, bprop

    8-40  : load0In<0-1>, load1In<0-1>, load2In<0-1>, load3In<0-1>, inOffset<0-3>, jrst_<0-3>, load<0-3>, max, maxIdx, lutOffset

    8-13  ~ tid, m, p, q, k, N

    14-40 ~ mp, magic_P, P, QN, PQN, MPQN, JRST, iOut

    14-63 ~ DHWN, HWN, WN, str_j, str_d, str_h, str_w, pad_j, pad_d, pad_h, pad_w, j, t, r, s, kj, mt, pr, qs, x, y, z, c, i, x0, xW, y0, yH, z0, zD, c0, cC, lutStore, jrst, rst, rs, warp_count, RST, RS, S, magic_S, magic_RS, magic_RST

</REGISTER_MAPPING>

// Thread and block coordinates. Each S2R signals its own scoreboard barrier
// (1-4); the first consumer of each value below carries the matching wait
// mask (01, 02, 04, 08).
--:-:1:-:1      S2R tid,  SR_TID.X; // tid==n
--:-:2:-:1      S2R q,  SR_CTAID.X;
--:-:3:-:1      S2R mp, SR_CTAID.Y;
--:-:4:-:1      S2R k,  SR_CTAID.Z;

<SCHEDULE_BLOCK>
// P6 = tid >= 32: such threads skip the lookup-table build (see the branch
// that follows this block)
01:-:-:-:1      ISETP.GE.AND P6, PT, tid, 32, PT;

// P5 = bprop mode
--:-:-:-:1      ISETP.NE.AND P5, PT, RZ, param_mode, PT;

// Pull the parameters needed for the output index into registers
--:-:-:-:1      MOV magic_P, param_magic_P;
--:-:-:-:1      MOV N,    param_N;
--:-:-:-:1      MOV P,    param_P;
--:-:-:-:1      MOV QN,   param_QN;
--:-:-:-:1      MOV PQN,  param_PQN;
--:-:-:-:1      MOV MPQN, param_MPQN;
--:-:-:-:1      MOV JRST, param_JRST;

// m = mp / P
// p = mp % P
// The quotient is formed as (magic_P * mp) >> shift_P, the remainder as
// mp - m*P.
04:-:-:-:1      XMAD.LO2 m, magic_P, mp, RZ;
--:-:-:-:1      SHR.U32 m, m, param_shift_P;
--:-:-:-:1      VMAD.U16.U16 p, -m, P, mp;

// I += n
// (shift of 1: each element is two bytes wide)
--:-:-:-:1      LEA      In0.CC, tid, param_I[0],     1;
--:-:-:-:1      LEA.HI.X In1,    tid, param_I[1], RZ, 1;

// B += n, formed the same way as the input address
--:-:-:-:1      LEA      Back0.CC, tid, param_B[0],     1;
--:-:-:-:1      LEA.HI.X Back1,    tid, param_B[1], RZ, 1;

// iOut = k*MPQN + m*PQN + p*QN + q*N + n
02:-:-:-:1      XMAD     iOut,    N, q, tid;
--:-:-:-:1      XMAD.LO2 iOut,   QN, p, iOut;
--:-:-:-:1      XMAD.LO2 iOut,  PQN, m, iOut;
08:-:-:-:1      XMAD.LO2 iOut, MPQN, k, iOut;
// Out = O + iOut*2
--:-:-:-:1      LEA      Out0.CC, iOut, param_O[0],     1;
--:-:-:-:1      LEA.HI.X Out1,    iOut, param_O[1], RZ, 1;

// bprop mode only: read the 16-bit value at this thread's output position.
// The load signals barrier 5; bprop is not consumed until the epilogue.
--:-:5:-:1  @P5 LDG.E.CI.U16 bprop, [Out];

// P0..P3 = lane n of the first group of four table entries exists (JRST > n)
--:-:-:-:1      ISETP.GT.AND P0, PT, JRST, RZ, PT;
--:-:-:-:1      ISETP.GT.AND P1, PT, JRST, 1,  PT;
--:-:-:-:1      ISETP.GT.AND P2, PT, JRST, 2,  PT;
--:-:-:-:1      ISETP.GT.AND P3, PT, JRST, 3,  PT;

</SCHEDULE_BLOCK>

// Threads with tid >= 32 go straight to the common setup at END_SETUP
--:-:-:-:5  @P6 BRA.U END_SETUP;

// Lookup-table build. Only threads with tid < 32 reach this point (the others
// branched to END_SETUP). Thread tid fills table entries tid, tid+32, tid+64,
// ... with the byte offset of the corresponding input element, or with a
// negative value when that element lies outside the input bounds.
<SCHEDULE_BLOCK>
// warp_count starts at 32: the number of entries covered once the first pass
// is done
--:-:-:-:1      MOV warp_count, 32;
// first table entry handled by this thread
--:-:-:-:1      MOV jrst,       tid;
--:-:-:-:1      MOV magic_RST,  param_magic_RST;
--:-:-:-:1      MOV magic_RS,   param_magic_RS;
--:-:-:-:1      MOV magic_S,    param_magic_S;
--:-:-:-:1      MOV pad_j,      param_pad_j;
--:-:-:-:1      MOV pad_d,      param_pad_d;
--:-:-:-:1      MOV pad_h,      param_pad_h;
--:-:-:-:1      MOV pad_w,      param_pad_w;
--:-:-:-:1      MOV str_j,      param_str_j;
--:-:-:-:1      MOV str_d,      param_str_d;
--:-:-:-:1      MOV str_h,      param_str_h;
--:-:-:-:1      MOV str_w,      param_str_w;
--:-:-:-:1      MOV RST,        param_RST;
--:-:-:-:1      MOV RS,         param_RS;
--:-:-:-:1      MOV S,          param_S;
--:-:-:-:1      MOV WN,         param_WN;
--:-:-:-:1      MOV HWN,        param_HWN;
--:-:-:-:1      MOV DHWN,       param_DHWN;

// Window origin for this output position, one coordinate per dimension:
// kj = k * str_j - pad_j
// mt = m * str_d - pad_d
// pr = p * str_h - pad_h
// qs = q * str_w - pad_w
--:-:-:-:1      VMAD.U16.U16 kj, k, str_j, -pad_j;
--:-:-:-:1      VMAD.U16.U16 mt, m, str_d, -pad_d;
--:-:-:-:1      VMAD.U16.U16 pr, p, str_h, -pad_h;
--:-:-:-:1      VMAD.U16.U16 qs, q, str_w, -pad_w;
</SCHEDULE_BLOCK>

LUT_LOOP:

<SCHEDULE_BLOCK>
// warp synchronous loop while warp_count < JRST
// P4 = entries remain after this pass; warp_count then advances by the 32
// entries this pass covers.
--:-:-:-:1      ISETP.LT.AND P4, PT, warp_count, param_JRST, PT;
--:-:-:-:1      IADD warp_count, warp_count, 32;
// Split the flat index jrst into (j, t, r, s). Each division is a multiply by
// a precomputed magic value followed by a right shift; each remainder is
// dividend - quotient*divisor.
// j   = jrst / RST
// rst = jrst % RST
--:-:-:-:1      XMAD.LO2 j, magic_RST, jrst, RZ;
--:-:-:-:1      SHR.U32 j, j, param_shift_RST;
--:-:-:-:1      VMAD.U16.U16 rst, -j, RST, jrst;
// t =  rst / RS
// rs = rst % RS
--:-:-:-:1      XMAD.LO2 t, magic_RS, rst, RZ;
--:-:-:-:1      SHR.U32 t, t, param_shift_RS;
--:-:-:-:1      VMAD.U16.U16 rs, -t, RS, rst;
// r = rs / S
// s = rs % S
--:-:-:-:1      XMAD.LO2 r, magic_S, rs, RZ;
--:-:-:-:1      SHR.U32 r, r, param_shift_S;
--:-:-:-:1      VMAD.U16.U16 s, -r, S, rs;
// Input coordinates = window origin + position inside the window
// x = qs + s
// y = pr + r
// z = mt + t
// c = kj + j
--:-:-:-:1      IADD x, qs, s;
--:-:-:-:1      IADD y, pr, r;
--:-:-:-:1      IADD z, mt, t;
--:-:-:-:1      IADD c, kj, j;
// i = (c*DHWN + z*HWN + y*WN + x*N) * 2
// The wait mask 01 on the first XMAD holds back the overwrite of i until the
// STS of the previous pass (read barrier 1) has picked up its operands.
01:-:-:-:1      XMAD     i, N,    x, RZ;
--:-:-:-:1      XMAD.LO2 i, WN,   y, i;
--:-:-:-:1      XMAD.LO2 i, HWN,  z, i;
--:-:-:-:1      XMAD.LO2 i, DHWN, c, i;
--:-:-:-:1      SHL i, i, 1;
// Bounds check x,y,z,c and make i negative if outside
// Each ISET yields a mask for one failed bound; LOP3 with table 0xfe is the
// three-input OR, so any failed bound is merged into i.
--:-:-:-:1      ISET.LT.AND x0, x, RZ, PT;
--:-:-:-:1      ISET.GE.AND xW, x, param_W, PT;
--:-:-:-:1      ISET.LT.AND y0, y, RZ, PT;
--:-:-:-:1      ISET.GE.AND yH, y, param_H, PT;
--:-:-:-:1      ISET.LT.AND z0, z, RZ, PT;
--:-:-:-:1      ISET.GE.AND zD, z, param_D, PT;
--:-:-:-:1      ISET.LT.AND c0, c, RZ, PT;
--:-:-:-:1      ISET.GE.AND cC, c, param_C, PT;
--:-:-:-:1      LOP3.LUT i, i, x0, xW, 0xfe;
--:-:-:-:1      LOP3.LUT i, i, y0, yH, 0xfe;
<ORDERED>
--:-:-:-:1      LOP3.LUT i, i, z0, zD, 0xfe;
// lutStore = jrst * 4 (one 32-bit table slot per entry); it must be taken
// before jrst advances to this thread's next entry
--:-:-:-:1      SHL lutStore, jrst, 2;
--:-:-:-:1      IADD jrst, jrst, 32;
</ORDERED>
--:-:-:-:1      LOP3.LUT i, i, c0, cC, 0xfe;
// Store i imgOffset into the shared lookup table
--:1:-:-:1      STS [lutStore], i;
</SCHEDULE_BLOCK>

// Another pass while table entries remain
--:-:-:-:5  @P4 BRA.U LUT_LOOP;

END_SETUP:

// Common entry for all threads: initialise the running max and loop counters.
// NOTE(review): the STS at the end of LUT_LOOP signals read barrier 1 and
// nothing between that loop's exit and the MOVs below waits on it. If any of
// these names is allocated to the same register as i or lutStore, the store's
// operands could be overwritten early. Confirm against the actual register
// allocation.

// P4 is reused from here on: P4 = overlap flag is non-zero
--:-:-:-:1      ISETP.NE.AND P4, PT, RZ, param_overlap, PT;
--:-:-:-:1      MOV32I max, 0xff7fffff; // -FLT_MAX
// maxIdx stays 0 unless some lane beats the initial max
--:-:-:-:1      MOV maxIdx, RZ;
// byte offset of the next group of four table entries
--:-:-:-:1      MOV lutOffset, RZ;
// table entry index handled by each of the four lanes
--:-:-:-:1      MOV jrst_0, RZ;
--:-:-:-:1      MOV jrst_1, 1;
--:-:-:-:1      MOV jrst_2, 2;
--:-:-:-:0      MOV jrst_3, 3;

// Every thread waits here before reading the table written by threads 0-31
--:-:-:-:5      BAR.SYNC 0;

// Main loop: four table entries per iteration. On entry P0..P3 say whether
// the entry index of each lane is below JRST.
LOOP:

// Read four consecutive 32-bit offsets (16 bytes) into inOffset0..3
--:-:1:-:2      LDS.U.128 inOffset, [lutOffset];

// A lane stays enabled only if its offset is non-negative (in bounds) and
// its entry index was in range
01:-:-:-:2      ISETP.GE.AND P0, PT, inOffset0, RZ, P0;
--:-:-:-:2      ISETP.GE.AND P1, PT, inOffset1, RZ, P1;
--:-:-:-:2      ISETP.GE.AND P2, PT, inOffset2, RZ, P2;
--:-:-:-:1      ISETP.GE.AND P3, PT, inOffset3, RZ, P3;

// Lane 0: address = In + inOffset0. The address add runs unconditionally
// here (lanes 1-3 predicate it); only the load itself depends on P0.
--:-:-:-:6      IADD   load0In0.CC, In0, inOffset0;
--:-:-:-:2      IADD.X load0In1,    In1, RZ;
--:-:1:-:1  @P0 LDG.E.CI.U16 load0, [load0In];

// Lanes 1-3: same address computation and 16-bit load, each on its own
// barrier (2, 3, 4)
--:-:-:-:6  @P1 IADD   load1In0.CC, In0, inOffset1;
--:-:-:-:2  @P1 IADD.X load1In1,    In1, RZ;
--:-:2:-:1  @P1 LDG.E.CI.U16 load1, [load1In];

--:-:-:-:6  @P2 IADD   load2In0.CC, In0, inOffset2;
--:-:-:-:2  @P2 IADD.X load2In1,    In1, RZ;
--:-:3:-:1  @P2 LDG.E.CI.U16 load2, [load2In];

--:-:-:-:6  @P3 IADD   load3In0.CC, In0, inOffset3;
--:-:-:-:2  @P3 IADD.X load3In1,    In1, RZ;
--:-:4:-:1  @P3 LDG.E.CI.U16 load3, [load3In];

// Advance the per-lane entry indices and the table offset to the next group
// while the loads are in flight
--:-:-:-:1      IADD jrst_0, jrst_0, 4;
--:-:-:-:1      IADD jrst_1, jrst_1, 4;
--:-:-:-:1      IADD jrst_2, jrst_2, 4;
--:-:-:-:1      IADD jrst_3, jrst_3, 4;
--:-:-:-:0      IADD lutOffset, lutOffset, 16;

// Widen each loaded fp16 value to fp32 in place, waiting on that lane's load
01:-:1:-:4  @P0 F2F.F32.F16 load0, load0;
02:-:2:-:4  @P1 F2F.F32.F16 load1, load1;
04:-:3:-:4  @P2 F2F.F32.F16 load2, load2;
08:-:4:-:1  @P3 F2F.F32.F16 load3, load3;

// Per lane: if enabled and strictly greater than the running max, take the
// value and remember its offset (a tie keeps the earlier entry). The
// predicate is then re-armed for the next group: P = lane's next entry
// index < JRST.
01:-:-:Y:d      FSETP.GT.AND P0, PT, load0, max, P0;
--:-:-:-:1  @P0 MOV max,    load0;
--:-:-:-:1  @P0 MOV maxIdx, inOffset0;
--:-:-:-:4      ISETP.LT.AND P0, PT, jrst_0, param_JRST, PT;

02:-:-:Y:d      FSETP.GT.AND P1, PT, load1, max, P1;
--:-:-:-:1  @P1 MOV max,    load1;
--:-:-:-:1  @P1 MOV maxIdx, inOffset1;
--:-:-:-:4      ISETP.LT.AND P1, PT, jrst_1, param_JRST, PT;

04:-:-:Y:d      FSETP.GT.AND P2, PT, load2, max, P2;
--:-:-:-:1  @P2 MOV max,    load2;
--:-:-:-:1  @P2 MOV maxIdx, inOffset2;
--:-:-:-:4      ISETP.LT.AND P2, PT, jrst_2, param_JRST, PT;

08:-:-:Y:d      FSETP.GT.AND P3, PT, load3, max, P3;
--:-:-:-:1  @P3 MOV max,    load3;
--:-:-:-:1  @P3 MOV maxIdx, inOffset3;
--:-:-:-:0      ISETP.LT.AND P3, PT, jrst_3, param_JRST, PT;

// Keep going while lane 0 of the next group is still inside the table
--:-:-:-:5  @P0 BRA.U LOOP;

// Epilogue: write the result.
//   fprop (P5 clear): store the running max, narrowed to fp16, to [Out].
//   bprop (P5 set)  : send the 16-bit value read from [Out] during setup
//                     (bprop) to Back + maxIdx, with a plain store or, when
//                     the overlap flag is set (P4), with a packed f16x2
//                     atomic add on the enclosing 32-bit word.
--:-:1:-:1 @!P5 F2F.F16.F32 max, max;

// fprop
--:-:-:-:5  @P5 BRA.U BPROP;

// wait mask 01: the F2F result above
01:-:-:-:1      STG.E.CG.U16 [Out], max;
--:-:-:-:5      BRA.U EXIT;

BPROP:

// bprop
// Back += maxIdx (64-bit add, carry passed through CC)
--:-:-:-:6      IADD   Back0.CC, Back0, maxIdx;
--:-:-:-:1      IADD.X Back1,    Back1, RZ;

// P0 = Back0 & 2
// Tested on the final address, after maxIdx has been added, because P0 picks
// which half of the 32-bit word the atomic add targets. Testing the base
// address instead gives the same answer only while maxIdx is a multiple of 4.
// The stall leaves room for the predicate to settle before OVERLAP uses it.
--:-:-:Y:d      LOP.AND.NZ P0, RZ, Back0, 0x2;

--:-:-:-:5  @P4 BRA.U OVERLAP;

// bprop was loaded with write barrier 5 during setup; wait mask 10 makes sure
// the load has landed before the value is stored.
10:-:-:-:1      STG.E.CG.U16 [Back0], bprop;

--:-:-:-:5      BRA.U EXIT;

OVERLAP:

// Place the fp16 value in the half of the word selected by P0 and leave the
// other half zero, so the packed add does not change the neighbouring
// element. Both variants wait on barrier 5 (mask 10) for the bprop load.
10:-:-:-:1 @!P0 LOP32I.AND bprop, bprop, 0x0000ffff;
10:-:-:-:1  @P0 SHL        bprop, bprop, 16;
// upper half selected: round the address down to the 4-byte word boundary
--:-:-:-:2  @P0 LOP32I.AND Back0, Back0, 0xfffffffc;

--:-:-:-:1      RED.E.ADD.F16x2.FTZ.RN [Back0], bprop;


EXIT:

--:-:-:-:5      EXIT;
